为了让人工神经元模型去实现一些具体的功能,我们在后续设计了一个激活函数,使得人工神经元具备二元线性分类的能力,这套新的模型被称为“感知机(perceptron)”。


下面给出感知机的训练算法:
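输入:数据集 $D = \{(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n)\}$,最大训练次数 $E$
- 初始化:$w_0 \leftarrow 0$,$e \leftarrow 0$,$k \leftarrow 0$
- For $e \leftarrow 1$ to $E$ do
  - 感知机遍历数据集 $D$
  - $i \leftarrow 0$
  - while $i < n$ do
    - $i \leftarrow i + 1$
    - $\hat{y}_i \leftarrow w_k^{T} x_i$
    - if $\hat{y}_i \, y_i \le 0$ then
      - $w_{k+1} \leftarrow w_k + y_i x_i$
      - $k \leftarrow k + 1$
输出:$w_k$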
下面给出案例,以及可供参考的使用Java语言完成训练的有关代码:


import java.util.Arrays;
import java.util.Locale;
/**
 * One labelled training sample: a feature vector {@code x} and its class label {@code y}.
 *
 * <p>The fields stay public because the rest of the file reads them directly
 * ({@code point.x}, {@code point.y}); they are {@code final} so a sample cannot be
 * re-pointed after construction.
 */
class DataPoint {
    /** Feature vector (three-dimensional in the bundled example). */
    public final double[] x;
    /** Class label, either +1 or -1. */
    public final int y;

    /**
     * Creates a sample.
     *
     * @param x feature vector; copied, so later changes to the caller's array do not
     *          alter this sample
     * @param y class label, must be +1 or -1
     * @throws NullPointerException     if {@code x} is {@code null}
     * @throws IllegalArgumentException if {@code y} is neither +1 nor -1
     */
    public DataPoint(double[] x, int y) {
        if (x == null) {
            throw new NullPointerException("x");
        }
        // The perceptron update w += y * x only makes sense for labels in {-1, +1}.
        if (y != 1 && y != -1) {
            throw new IllegalArgumentException("label must be +1 or -1, got " + y);
        }
        this.x = x.clone();
        this.y = y;
    }
}
// 感知机类,涵盖初始化、预测、训练、获取四个方法
class Perceptron {
public double[] weights; // 权重向量(3维)
public int iterations; // 迭代次数
// 初始化函数:初始权重(0,0,0)
public Perceptron(int dimension) {
this.weights = new double[dimension]; // 初始化权重为0
Arrays.fill(this.weights, 0); // 初始权重为 (0, 0, 0)
this.iterations = 0;
}
// 预测函数:使用 sign(w^T x) 预测标签
public int predict(double[] x) {
double dotProduct = 0;
//用for循环计算向量点乘
for (int i = 0; i < x.length; i++) {
dotProduct += weights[i] * x[i];
}
return (dotProduct > 0) ? 1 : -1;
}
// 训练函数:更新权重
public void train(DataPoint point) {
int prediction = predict(point.x);
// 如果预测错误(y * (w^T x) <= 0),更新权重
if (point.y * prediction <= 0) {
for (int i = 0; i < weights.length; i++) {
weights[i] += point.y * point.x[i]; // w = w + y * x
}
iterations++;
System.out.println("更新权重(迭代 " + iterations + "):
w = " + Arrays.toString(weights));
} else {
System.out.println("正确分类,无需更新:w = " +
Arrays.toString(weights));
}
}
// 获取当前权重,保留两位小数
public double[] getWeights() {
double[] formattedWeights = new double[weights.length];
for (int i = 0; i < weights.length; i++) {
formattedWeights[i] = Double.parseDouble
(String.format("%.2f", weights[i]));
}
return formattedWeights;
}
}
public class PerceptronExample {
public static void main(String[] args) {
// 定义数据集(根据表格数据)
DataPoint[] data = new DataPoint[12];
data[0] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
data[1] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
data[2] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
data[3] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
data[4] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
data[5] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
data[6] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
data[7] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
data[8] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
data[9] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
data[10] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
data[11] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
// 创建感知机实例(3维特征)
Perceptron perceptron = new Perceptron(3);
// 训练模型:遍历数据集
System.out.println("开始训练,初始权重:w = " +
Arrays.toString(perceptron.getWeights()));
for (int i = 0; i < data.length; i++) {
System.out.println("\n处理数据点 " + (i + 1) + ": x =
" + Arrays.toString(data[i].x) + ", y = " + data[i].y);
perceptron.train(data[i]);
// 预测当前点,验证结果
int prediction = perceptron.predict(data[i].x);
System.out.println("预测结果:y = " + prediction);
// 输出最终权重
System.out.println("\n训练完成,最终权重:w = " +
Arrays.toString(perceptron.getWeights()));
}
}
}